# Instruction delimiters; in this file they wrap prompts built for models whose
# name contains "llama-2" (see apply_icl_prompt).
B_INST = "[INST]"
E_INST = "[/INST]"
# Presumably system-prompt delimiters (name suggests so); not referenced in
# this part of the file.
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

def get_prompt_template(prompt_template_style="base", prompt_template_pth=None):
    """Return the %-style prompt template for a supported math-style dataset.

    Args:
        prompt_template_style: Dataset/style identifier. Any value starting
            with ``math``, ``gsm8k``, ``prm800k`` or ``mgsm`` is accepted.
            Note that the default ``"base"`` is not in that set, so calling
            this function with no arguments raises ``ValueError``.
        prompt_template_pth: Unused; kept only for backward compatibility
            with existing callers.

    Returns:
        A template string ("Question: <text>", a newline, then "Answer:")
        with a single ``%s`` placeholder for the question text.

    Raises:
        ValueError: If ``prompt_template_style`` is not a supported style.
    """
    # str.startswith accepts a tuple, which checks every prefix in one call.
    if prompt_template_style.startswith(("math", "gsm8k", "prm800k", "mgsm")):
        return "Question: %s\nAnswer:"
    raise ValueError(f"Invalid prompt template style: {prompt_template_style!r}.")

def apply_prompt_template(prompt_template_style='base', dataset=None, tokenizer=None, prefix="", return_dialogs=False):
    """Format each question with the style's template and tokenize the result.

    Args:
        prompt_template_style: Style name forwarded to ``get_prompt_template``.
        dataset: Iterable of question strings to format.
        tokenizer: Object exposing ``encode``; called once per formatted prompt,
            in dataset order.
        prefix: Text prepended to every question before it is substituted
            into the template.
        return_dialogs: If truthy, also return the formatted prompt strings.

    Returns:
        The list of ``tokenizer.encode`` results, or a ``(encoded, formatted)``
        pair when ``return_dialogs`` is truthy.
    """
    template = get_prompt_template(prompt_template_style)

    formatted, encoded = [], []
    for question in dataset:
        text = template % (prefix + question)
        formatted.append(text)
        # Encode right after formatting so tokenizer calls interleave with
        # formatting exactly as before.
        encoded.append(tokenizer.encode(text))

    if not return_dialogs:
        return encoded
    return encoded, formatted

def _icl_instruction(test_dataset, model_name, k):
    """Select the text placed between the ICL examples and the test question.

    Returns:
        ``(separator, wrap_in_inst)``. The final prompt is
        ``icl_prompt + separator + prompt``; when ``wrap_in_inst`` is true it is
        additionally wrapped as ``B_INST + " " + ... + " " + E_INST``.

    Raises:
        ValueError: If ``test_dataset`` is not supported, or ``k`` is neither
            positive nor zero.
    """
    model = model_name.lower()
    is_llama2 = "llama-2" in model

    if test_dataset in ("gsm8k-plus-mini", "gsm8k-plus"):
        # These instructions additionally tell the model how to answer "#### None".
        if is_llama2:
            few_shot = "You need to solve the final question and give the final answer in the format: \n#### {result}\nIf you dont know the answer, please answer #### None.\n"
            zero_shot = "Let's think step by step. You need to solve the question and answer in the format: \n#### {result}\nIf you dont know the answer, please answer #### None.\n"
            wrap = True
        else:
            few_shot = "Let's think step by step. You need to solve the final question and answer in the format: \n#### {result}\nIf the information isn't enough to solve the question and there is no valid answer, please answer in the format: \n#### None\n"
            zero_shot = "Let's think step by step. You need to solve the question and answer in the format: \n#### {result}\nIf the information isn't enough to solve the question and there is no valid answer, please answer in the format: \n#### None\n"
            wrap = False
    # Checked after the gsm8k-plus case on purpose: "gsm8k-plus" also starts with "gsm8k".
    elif test_dataset.startswith(("math", "gsm8k", "prm800k", "mgsm")):
        if is_llama2:
            few_shot = "Let's think step by step. You need to solve the final question and give the final answer in the format: \n#### {result}\n"
            zero_shot = "Let's think step by step. You need to solve the question and answer in the format: \n#### {result}\n"
            wrap = True
        elif "qwen" in model and "math" in model:
            # Qwen math models get no instruction text: examples and question
            # are separated only by a blank line, regardless of k.
            few_shot = zero_shot = ""
            wrap = False
        else:
            few_shot = "Let's think step by step. You need to solve the final question and answer in the format: \n#### {result}\n"
            zero_shot = "Let's think step by step. You need to solve the question and answer in the format: \n#### {result}\n"
            wrap = False
    else:
        raise ValueError(f"Unsupported test_dataset for ICL prompting: {test_dataset!r}.")

    if k > 0:
        instruction = few_shot
    elif k == 0:
        instruction = zero_shot
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(f"k must be a non-negative number of shots, got {k!r}.")
    return "\n\n" + instruction, wrap


def apply_icl_prompt(test_dataset_inputs, train_dataset_inputs, train_dataset_outputs, idx_mat, k, model_name, test_dataset, apply_chat_template):
    """Build one in-context-learning prompt per test question.

    For the i-th test input, the demonstrations are the training pairs whose
    indices are listed in ``idx_mat[i]``. Each pair is rendered as
    ``input + " " + output``, and pairs are joined by blank lines. A
    dataset/model-specific instruction and the test question follow. For
    models whose name contains "llama-2", the whole prompt is wrapped in
    ``B_INST ... E_INST``.

    Args:
        test_dataset_inputs: Iterable of test question strings.
        train_dataset_inputs: Indexable sequence of demonstration questions.
        train_dataset_outputs: Indexable sequence of demonstration answers,
            aligned with ``train_dataset_inputs``.
        idx_mat: ``idx_mat[i]`` is an iterable of training indices used as
            demonstrations for the i-th test input.
        k: Shot count. Only its comparison with 0 is used here: positive
            selects the "final question" wording, zero selects the zero-shot
            wording.
        model_name: Model identifier, matched case-insensitively.
        test_dataset: Dataset name that selects the instruction text.
        apply_chat_template: Unused; kept for interface compatibility.

    Returns:
        A list of prompt strings, one per test input.

    Raises:
        ValueError: If ``test_dataset`` is unsupported or ``k`` is negative.
            The check runs before any prompt is built, so it also applies
            when ``test_dataset_inputs`` is empty.
    """
    # The instruction depends only on dataset/model/k, so resolve it once.
    separator, wrap_in_inst = _icl_instruction(test_dataset, model_name, k)

    icl_prompts = []
    for i, prompt in enumerate(test_dataset_inputs):
        examples = [train_dataset_inputs[j] + " " + train_dataset_outputs[j] for j in idx_mat[i]]
        whole = "\n\n".join(examples) + separator + prompt
        if wrap_in_inst:
            whole = B_INST + " " + whole + " " + E_INST
        icl_prompts.append(whole)
    return icl_prompts